# RAL implementations

load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "tf_cc_shared_object",
    "tf_cc_test",
    "tf_copts",
    "tf_gpu_kernel_library",
    "tf_gpu_library",
    "tf_native_cc_binary",
)
load(
    "@local_config_cuda//cuda:build_defs.bzl",
    "cuda_default_copts",
    "if_cuda",
    "if_cuda_is_configured",
    "cuda_library",
)
load("@local_config_rocm//rocm:build_defs.bzl", "if_dcu", "if_rocm_is_configured")
load(
    "@com_google_protobuf//:protobuf.bzl",
    "cc_proto_library",
)
load("//mlir/util:util.bzl",
     "disc_cc_library",
     "if_cuda_or_rocm",
     "if_platform_alibaba",
     "if_blade_gemm",
     "if_mkldnn",
     "if_disc_aarch64",
     "if_disc_x86",
     "if_torch_disc",
     "if_internal_serving",
     )

# Package-wide defaults: every target in this package is publicly visible
# unless it sets its own `visibility`.
package(
    licenses = ["notice"],  # Apache 2.0
    default_visibility = ["//visibility:public"],
)

# ral_logging: built from ral_logging.{h,cc}; it has no deps of its own and is
# depended on by most other RAL libraries in this package.
cc_library(
    name = "ral_logging",
    srcs = ["ral_logging.cc"],
    hdrs = ["ral_logging.h"],
    alwayslink = 1,
)

# ral_md5: built from ral_md5.{h,cc}, with no deps of its own.
cc_library(
    name = "ral_md5",
    srcs = ["ral_md5.cc"],
    hdrs = ["ral_md5.h"],
    alwayslink = 1,
)

# ral_context: compiles ral_context.cc and ral_helper.cc and exports the
# context, helper, driver and base headers. The CUDA headers dependency is
# unconditional (not wrapped in if_cuda_is_configured).
cc_library(
    name = "ral_context",
    srcs = [
        "ral_context.cc",
        "ral_helper.cc",
    ],
    hdrs = [
        "ral_context.h",
        "ral_helper.h",
        "ral_driver.h",
        "ral_base.h",
    ],
    deps = [
        ":ral_logging",
        "@local_config_cuda//cuda:cuda_headers",
    ],
    alwayslink = 1,
)

# ral_cpu_driver: driver sources under device/cpu/, layered on ral_context.
cc_library(
    name = "ral_cpu_driver",
    srcs = ["device/cpu/cpu_driver.cc"],
    hdrs = ["device/cpu/cpu_driver.h"],
    deps = [
        ":ral_context",
        ":ral_logging",
    ],
    alwayslink = 1,
)

# ral_gpu_driver: driver sources under device/gpu/, layered on ral_context.
cc_library(
    name = "ral_gpu_driver",
    srcs = ["device/gpu/gpu_driver.cc"],
    hdrs = ["device/gpu/gpu_driver.h"],
    deps = [
        ":ral_context",
        ":ral_logging",
    ],
    alwayslink = 1,
)

# ral_library: the ral_api.{h,cc} pair, layered on ral_context.
cc_library(
    name = "ral_library",
    srcs = ["ral_api.cc"],
    hdrs = ["ral_api.h"],
    deps = [
        ":ral_context",
        ":ral_logging",
    ],
    alwayslink = 1,
)

# ral_tf_context: the TensorFlow-side RAL context and kernel implementation
# built against the TensorFlow source tree (see ral_tf_context_bridge for the
# variant that avoids TF source deps).
#
# TensorFlow is an external repository in this workspace (its .bzl files are
# loaded from @org_tensorflow), so its targets must be referenced with the
# @org_tensorflow prefix; a bare //tensorflow/... label would resolve inside
# the current workspace instead.
cc_library(
    name = "ral_tf_context",
    srcs = [
        "context/tensorflow/tf_context_impl.cc",
        "context/tensorflow/tf_kernel_impl.cc",
    ],
    hdrs = ["context/tensorflow/tf_context_impl.h"],
    deps = [
        ":context_util",
        ":common_context",
        ":ral_context",
        ":ral_cpu_driver",
        ":ral_gpu_driver",
        ":ral_library",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/stream_executor",
        "@org_tensorflow//tensorflow/core:lib",
        "@local_config_cuda//cuda:cuda_headers",
    ],
    alwayslink = 1,
)

# context_util: header-only library exposing context/context_util.h.
cc_library(
    name = "context_util",
    hdrs = ["context/context_util.h"],
    deps = [
        ":ral_context",
        ":ral_logging",
    ],
    alwayslink = 1,
)

# pdll_util: the context/pdll_util.{h,cc} pair, layered on ral_context.
cc_library(
    name = "pdll_util",
    srcs = ["context/pdll_util.cc"],
    hdrs = ["context/pdll_util.h"],
    deps = [
        ":ral_context",
        ":ral_logging",
    ],
    alwayslink = 1,
)

# ral_metadata: standalone ral_metadata.{h,cc} with no deps; unit-tested by
# ral_metadata_test.
cc_library(
    name = "ral_metadata",
    srcs = ["ral_metadata.cc"],
    hdrs = ["ral_metadata.h"],
    alwayslink = 1,
)

# ral_interface: aggregate target with no sources of its own. It bundles the
# core RAL libraries and adds the GPU driver only for CUDA/ROCm builds.
cc_library(
    name = "ral_interface",
    deps = [
        ":ral_context",
        ":ral_logging",
        ":ral_library",
        ":ral_metadata",
    ] + if_cuda_or_rocm([":ral_gpu_driver"]),
    alwayslink = 1,
)

# Small unit test for :ral_metadata, linked against TensorFlow's test_main.
tf_cc_test(
    name = "ral_metadata_test",
    size = "small",
    srcs = ["ral_metadata_test.cc"],
    deps = [
        ":ral_metadata",
        "@org_tensorflow//tensorflow/core:test_main",
        "@org_tensorflow//tensorflow/core:test",
        "@org_tensorflow//tensorflow/core:testlib",
    ],
)

# blade_gemm_deps: picks a blade_gemm implementation. Presumably the
# @blade_gemm external repo is used for torch-disc builds and
# //tao/blade_gemm otherwise (first/second argument of if_torch_disc) --
# confirm against util.bzl.
cc_library(
    name = "blade_gemm_deps",
    deps = if_torch_disc(
        ["@blade_gemm//:blade_gemm"],
        ["//tao/blade_gemm:blade_gemm"],
    ),
)

# disc_blade_gemm_deps: forwards to :blade_gemm_deps only when
# if_platform_alibaba is enabled; empty otherwise.
cc_library(
    name = "disc_blade_gemm_deps",
    deps = if_platform_alibaba([":blade_gemm_deps"]),
    alwayslink = 1,
)

# collective_ops: collective-communication library compiled from
# collective.cu.cc.
#
# Besides collective.h, `hdrs` re-lists context headers that are owned by other
# targets in this package. cuda_context_impl.h in particular belongs to
# :ral_base_cuda_context_impl, which itself depends on this target, so it
# cannot be obtained through a dep without creating a cycle.
cc_library(
    name = "collective_ops",
    srcs = ["collective.cu.cc"],
    hdrs = [
        "collective.h",
        "context/base/cuda/cuda_context_impl.h",
        "context/base/cpu/cpu_context_impl.h",
        "context/base/base_context.h",
        "context/stream_executor_based_impl.h",
    ],
    # NOTE(review): GOOGLE_CUDA is defined unconditionally (also for ROCm and
    # CPU-only configurations); confirm collective.cu.cc relies on this.
    copts = ["-DGOOGLE_CUDA"],
    deps = [
        ":context_util",
        ":ral_context",
        ":ral_gpu_driver",
        ":ral_cpu_driver",
        ":ral_library",
        ":ral_logging",
        # Owns context/pdll_util.h and pdll_util.cc; depend on it rather than
        # re-listing its header here.
        ":pdll_util",
    ] + if_cuda_or_rocm([
        "@local_config_nccl//:nccl",
    ]),
    # Match every other RAL library in this package: keep all objects from
    # collective.cu.cc in the final link even if no symbol in them is
    # referenced directly (presumably it registers RAL entry points via static
    # initializers -- TODO confirm).
    alwayslink = 1,
)


# Depends on the mkldnn/oneDNN targets generated by the CMake-based build.
# disc_mkl_cmake: ideep headers plus the oneDNN libraries produced by
# //tao/third_party/mkldnn:build, together with Intel MKL on x86 or the Arm
# Compute Library on aarch64.
cc_library(
    name="disc_mkl_cmake",
    # Vendored ideep headers (three directory levels, non-recursive globs).
    hdrs=glob([
        "context/mkldnn/ideep/*.hpp",
        "context/mkldnn/ideep/ideep/*.hpp",
        "context/mkldnn/ideep/ideep/operators/*.hpp",
    ]),
    # Static archives are linked by path inside the build output directory.
    # Each set of archives is wrapped in --start-group/--end-group so the
    # linker rescans the archives in the group to resolve references between
    # them. The order of these flags matters; do not reorder.
    linkopts=[
        "-L$(location //tao/third_party/mkldnn:build)/install/lib64",
        "-L$(location //tao/third_party/mkldnn:build)/intel/lib",
        "-Wl,--start-group",
        "$(location //tao/third_party/mkldnn:build)/install/lib64/libdnnl.a",
        "-Wl,--end-group",
    ] + if_disc_x86([
        # Intel MKL: ILP64 interface, GNU OpenMP threading layer, core.
        "-Wl,--start-group",
        "$(location //tao/third_party/mkldnn:build)/intel/lib/libmkl_intel_ilp64.a",
        "$(location //tao/third_party/mkldnn:build)/intel/lib/libmkl_gnu_thread.a",
        "$(location //tao/third_party/mkldnn:build)/intel/lib/libmkl_core.a",
        "-Wl,--end-group",
    ]) + if_disc_aarch64([
        # Arm Compute Library archives.
        "-Wl,--start-group",
        "$(location //tao/third_party/mkldnn:build)/acl/ComputeLibrary/build/libarm_compute.a",
        "$(location //tao/third_party/mkldnn:build)/acl/ComputeLibrary/build/libarm_compute_core.a",
        "$(location //tao/third_party/mkldnn:build)/acl/ComputeLibrary/build/libarm_compute_graph.a",
        "-Wl,--end-group",
    ]),
    # Declares the build output so the $(location ...) references in
    # `linkopts` can be expanded.
    data=["//tao/third_party/mkldnn:build"],
    # Architecture-specific oneDNN target: plain oneDNN on x86, the ACL-backed
    # variant on aarch64.
    deps=if_disc_x86([
        "//tao/third_party/mkldnn:onednn",
    ]) + if_disc_aarch64([
        "//tao/third_party/mkldnn:onednn_acl",
    ]),
)

# Depends on mkldnn/oneDNN provided by external repositories in the Bazel workspace.
# disc_mkl_bazel: the same ideep headers as :disc_mkl_cmake, but linking
# oneDNN, MKL (x86) and the Arm Compute Library (aarch64) from external Bazel
# repositories instead of the CMake build. Used by the TF-bridge targets.
cc_library(
    name="disc_mkl_bazel",
    hdrs=glob([
        "context/mkldnn/ideep/*.hpp",
        "context/mkldnn/ideep/ideep/*.hpp",
        "context/mkldnn/ideep/ideep/operators/*.hpp",
    ]),
    # Order-sensitive: each archive set sits in its own --start-group /
    # --end-group pair. Do not reorder.
    linkopts=[
        "-Wl,--start-group",
        "$(locations @onednn//:onednn_lib)",
        "-Wl,--end-group",
        # Bazel shell-tokenizes each linkopt, so this single string becomes
        # four separate linker flags.
        "-lgomp -lpthread -lm -ldl",
    ] + if_disc_x86([
        "-Wl,--start-group",
        "$(location @mkl_static//:lib)/libmkl_intel_ilp64.a",
        "$(location @mkl_static//:lib)/libmkl_gnu_thread.a",
        "$(location @mkl_static//:lib)/libmkl_core.a",
        "-Wl,--end-group",
    ]) + if_disc_aarch64([
        "-Wl,--start-group",
        "$(location @acl_compute_library//:build)/libarm_compute.a",
        "$(location @acl_compute_library//:build)/libarm_compute_core.a",
        "$(location @acl_compute_library//:build)/libarm_compute_graph.a",
        "-Wl,--end-group",
    ]),
    # Every label expanded by $(location)/$(locations) in `linkopts` is
    # declared here.
    data=[
        "@onednn//:onednn_lib",
    ] + if_disc_x86([
        "@mkl_static//:lib",
    ]) + if_disc_aarch64([
        "@acl_compute_library//:build",
    ]),
    deps=[
        "@onednn//:onednn",  # acl is not needed in `deps` since the dependency is carried by `onednn`
    ] + if_disc_x86([
        "@mkl_include//:mkl_include",
    ]),
    alwayslink=1,
)

# common_context: shared RAL context implementation for the build against the
# TensorFlow source tree. Extra sources and headers are added per
# configuration:
#   - CUDA/ROCm: stream-executor based implementation.
#   - CUDA only: sparse and quantized GPU kernels.
#   - mkldnn: oneDNN and quantization implementations.
#   - aarch64: ACL implementation.
#   - blade_gemm: stream_executor_expand_impl.
# The TF-bridge counterpart is common_context_bridge below.
cc_library(
    name="common_context",
    srcs=[
        "context/common_context_impl.cc",
        "context/common_context_impl_pdll.cc",
    ] + if_cuda_or_rocm([
        "context/common_context_impl_cuda.cc",
        "context/stream_executor_based_impl.cc",
    ]) + if_cuda([
        "context/cuda_impl_sparse.cc",
        "context/quantized_gpu_kernel_impl.cc",
    ]) + if_mkldnn([
        "context/common_context_impl_mkldnn.cc",
        "context/common_context_impl_quantization.cc",
    ]) + if_disc_aarch64([
        "context/common_context_impl_acl.cc",
    ]) + if_blade_gemm([
        "context/stream_executor_expand_impl.cc",
    ]),
    hdrs=[
        "context/common_context_impl.h",
    ] + if_cuda_or_rocm([
        "context/stream_executor_based_impl.h",
    ]) + if_mkldnn([
        "context/common_context_impl_mkldnn.h",
        "context/common_context_impl_quantization.h",
    ]) + if_disc_aarch64([
        "context/common_context_impl_acl.h",
    ]),
    # OpenMP is always on. TENSORFLOW_USE_DCU is defined only for DCU builds
    # and TAO_RAL_USE_STREAM_EXECUTOR only for CUDA/ROCm builds.
    copts=[
        "-fopenmp",
    ] + if_dcu(["-DTENSORFLOW_USE_DCU=1"])
    + if_cuda_or_rocm(["-DTAO_RAL_USE_STREAM_EXECUTOR"]),
    # With blade_gemm, the static MKL archives from @mkl_static//:lib are
    # linked by path inside a --start-group/--end-group pair; that label is
    # listed in `data` so $(location) can expand. Keep this order.
    linkopts=[
        "-fopenmp",
        "-ldl",
        "-lm"
    ] + if_blade_gemm([
        "-Wl,--start-group",
        "$(location @mkl_static//:lib)/libmkl_intel_ilp64.a",
        "$(location @mkl_static//:lib)/libmkl_gnu_thread.a",
        "$(location @mkl_static//:lib)/libmkl_core.a",
        "-Wl,--end-group",
    ]),
    deps=[
        ":pdll_util",
        ":ral_context",
        ":ral_cpu_driver",
        ":ral_logging",
        ":ral_metadata",
        ":context_util",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/core/platform:status",
        "@org_tensorflow//tensorflow/core/platform:statusor",
    ] + if_cuda_is_configured([
        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:all_runtime",
        "@local_config_cuda//cuda:cuda_driver",
        "@local_config_cuda//cuda:cuda_headers",
    ]) + if_rocm_is_configured([
        "@org_tensorflow//tensorflow/stream_executor:rocm_platform",
        "@org_tensorflow//tensorflow/stream_executor/rocm:rocm_driver",
        "@local_config_rocm//rocm:rocm_headers",
    ]) + if_cuda_or_rocm([
        ":ral_gpu_driver",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:stream_executor_headers_lib",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/core/kernels:numeric_options_utils",
    ]) + if_blade_gemm([
        ":disc_blade_gemm_deps",
    ]) + if_mkldnn([
        # oneDNN/MKL from the CMake-generated build; the bridge variant uses
        # :disc_mkl_bazel instead.
        ":disc_mkl_cmake",
        ":ral_md5",
    ]),
    data=if_blade_gemm([
        "@mkl_static//:lib",
    ]),
    alwayslink=1,
)

# init_stream_executor: stream-executor initialization plus the CUDA stream
# wrapper under context/base/cuda/. The CUDA runtime/driver deps are added only
# when CUDA is configured, and the ROCm platform only when ROCm is configured.
cc_library(
    name = "init_stream_executor",
    srcs = [
        "context/init_stream_executor.cc",
        "context/base/cuda/cuda_stream.cc",
    ],
    hdrs = [
        "context/init_stream_executor.h",
        "context/base/cuda/cuda_stream.h",
    ],
    deps = [
        ":common_context",
        ":ral_logging",
        "@org_tensorflow//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ] + if_cuda_is_configured([
        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:all_runtime",
        "@local_config_cuda//cuda:cuda_driver",
        "@local_config_cuda//cuda:cuda_headers",
    ]) + if_rocm_is_configured(["@org_tensorflow//tensorflow/stream_executor:rocm_platform"]),
    alwayslink = 1,
)

# ral_base_cpu_context_impl: the base context and its CPU implementation.
# For CUDA/ROCm builds it also exports the stream-executor header and defines
# TAO_RAL_USE_STREAM_EXECUTOR. It is linked into libral_base_context.so.
cc_library(
    name = "ral_base_cpu_context_impl",
    srcs = [
        "context/base/cpu/cpu_context_impl.cc",
        "context/base/base_context.cc",
    ],
    hdrs = [
        "context/base/cpu/cpu_context_impl.h",
        "context/base/base_context.h",
    ] + if_cuda_or_rocm(["context/stream_executor_based_impl.h"]),
    copts = if_cuda_or_rocm(["-DTAO_RAL_USE_STREAM_EXECUTOR"]),
    deps = [
        ":context_util",
        ":common_context",
        ":ral_context",
        ":ral_cpu_driver",
        ":ral_library",
        ":ral_logging",
        "@com_google_absl//absl/strings",
        "@curl",
        "@org_tensorflow//third_party/eigen3",
    ] + if_cuda_or_rocm([
        "@org_tensorflow//tensorflow/core:stream_executor_headers_lib",
    ]) + if_blade_gemm([":disc_blade_gemm_deps"]),
    alwayslink = 1,
    visibility = ["//visibility:public"],
)

# ral_base_cuda_context_impl: CUDA implementation of the base context, built
# on top of the CPU base context, collective ops and stream-executor
# initialization. TAO_RAL_USE_STREAM_EXECUTOR is always defined here. ROCm and
# CUDA platform/driver deps are added according to the configured GPU stack.
disc_cc_library(
    name = "ral_base_cuda_context_impl",
    srcs = ["context/base/cuda/cuda_context_impl.cc"],
    hdrs = ["context/base/cuda/cuda_context_impl.h"],
    copts = ["-DTAO_RAL_USE_STREAM_EXECUTOR"],
    deps = [
        ":ral_base_cpu_context_impl",
        ":context_util",
        ":common_context",
        ":collective_ops",
        ":init_stream_executor",
        ":ral_context",
        ":ral_cpu_driver",
        ":ral_gpu_driver",
        ":ral_library",
        ":ral_logging",
        "@org_tensorflow//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
        "@curl",
    ] + if_rocm_is_configured([
        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:rocm_platform",
        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/rocm:rocm_driver",
    ]) + if_cuda_is_configured([
        "@local_config_nccl//:nccl",
        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cuda_platform",
        "@local_config_cuda//cuda:cuda_driver",
    ]),
    alwayslink = 1,
    visibility = ["//visibility:public"],
)

# libral_base_context.so: shared object bundling the CPU base context and, for
# CUDA/ROCm builds, the CUDA base context. Exported symbols are controlled by
# context_version_scripts.lds.
tf_cc_shared_object(
    name = "libral_base_context.so",
    # Plain list: the configuration does not vary these flags. Keep the order:
    # "-z defs" is split into two linker args by Bazel's shell tokenization,
    # and --version-script must be immediately followed by the script path.
    linkopts = [
        "-z defs",
        "-Wl,--version-script",
        "$(location //mlir/ral:context_version_scripts.lds)",
    ],
    deps = [
        ":ral_base_cpu_context_impl",
        # Declared here so the $(location ...) in linkopts can be expanded.
        "//mlir/ral:context_version_scripts.lds",
    ] + if_cuda_or_rocm([":ral_base_cuda_context_impl"]),
)

# ral_base_context_lib: wraps the libral_base_context.so produced by the
# tf_cc_shared_object of the same name in this package. Consumers can link
# against it, and the .so is shipped as runtime data. The package directory is
# added to the include path.
cc_library(
    name = "ral_base_context_lib",
    srcs = ["libral_base_context.so"],
    data = [":libral_base_context.so"],
    includes = ["."],
    visibility = ["//visibility:public"],
)


###########################################################################################

# tf_deps: TensorFlow dependencies for the bridge targets. if_internal_serving
# selects between two sets:
#   - TF source targets from @org_tensorflow;
#   - the prebuilt framework library and headers from @local_config_tf.
# Presumably the first set is used when internal serving is enabled -- confirm
# against util.bzl.
cc_library(
    name = "tf_deps",
    deps = if_internal_serving(
        [
            "@com_google_absl//absl/types:variant",
            "@org_tensorflow//tensorflow/core:lib",
            "@org_tensorflow//tensorflow/core:framework",
            "@org_tensorflow//tensorflow/core:core_cpu",
            "@org_tensorflow//tensorflow/core:stream_executor_no_cuda",
            "@org_tensorflow//tensorflow/stream_executor:stream_executor_impl",
        ],
        [
            "@local_config_tf//:libtensorflow_framework",
            "@local_config_tf//:tf_header_lib",
        ],
    ),
    alwayslink = 1,
)

# The following targets are for the TF-bridge-side build. They avoid all dependencies on TensorFlow source code and use only some of TensorFlow's `.bzl` utilities.

# common_context_bridge: TF-bridge variant of :common_context. It compiles the
# same source set but:
#   - depends on :tf_deps instead of @org_tensorflow source targets;
#   - uses :disc_mkl_bazel instead of :disc_mkl_cmake for mkldnn;
#   - defines DISC_BUILD_FROM_TF_BRIDGE.
cc_library(
    name="common_context_bridge",
    srcs=[
        "context/common_context_impl.cc",
        "context/common_context_impl_pdll.cc",
    ] + if_cuda_or_rocm([
        "context/common_context_impl_cuda.cc",
        "context/stream_executor_based_impl.cc",
    ]) + if_cuda([
        "context/cuda_impl_sparse.cc",
        "context/quantized_gpu_kernel_impl.cc",
    ]) + if_mkldnn([
        "context/common_context_impl_mkldnn.cc",
        "context/common_context_impl_quantization.cc",
    ]) + if_disc_aarch64([
        "context/common_context_impl_acl.cc",
    ]) + if_blade_gemm([
        "context/stream_executor_expand_impl.cc",
    ]),
    # Unlike :common_context, stream_executor_based_impl.h is exported in
    # every configuration here.
    hdrs=[
        "context/common_context_impl.h",
        "context/stream_executor_based_impl.h",
    ] + if_mkldnn([
        "context/common_context_impl_mkldnn.h",
        "context/common_context_impl_quantization.h",
    ]) + if_disc_aarch64([
        "context/common_context_impl_acl.h",
    ]),
    # ROCm builds compile the sources with "-x rocm"; internal-serving builds
    # add the TF_1_12 / IS_PAI_TF defines.
    copts=[
        "-fopenmp",
        "-DDISC_BUILD_FROM_TF_BRIDGE"
    ] + if_cuda_or_rocm([
        "-DTAO_RAL_USE_STREAM_EXECUTOR"
    ]) + if_dcu([
        "-DTENSORFLOW_USE_DCU=1"
    ]) + if_rocm_is_configured([
        "-DTENSORFLOW_USE_ROCM=1",
        "-x rocm",
    ]) + if_internal_serving([
        "-DTF_1_12",
        "-DIS_PAI_TF",
    ]),
    # With blade_gemm, the static MKL archives are linked by path inside a
    # --start-group/--end-group pair (the label is listed in `data`). Keep
    # this order.
    linkopts=[
        "-fopenmp",
        "-ldl",
        "-lm"
    ] + if_blade_gemm([
        "-Wl,--start-group",
        "$(location @mkl_static//:lib)/libmkl_intel_ilp64.a",
        "$(location @mkl_static//:lib)/libmkl_gnu_thread.a",
        "$(location @mkl_static//:lib)/libmkl_core.a",
        "-Wl,--end-group",
    ]),
    deps=[
        ":pdll_util",
        ":ral_context",
        ":ral_cpu_driver",
        ":ral_gpu_driver",
        ":ral_logging",
        ":ral_metadata",
        ":context_util",
        ":tf_deps",
    ] + if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_driver",
        "@local_config_cuda//cuda:cuda_headers",
    ]) + if_rocm_is_configured([
        "@local_config_rocm//rocm:rocm_headers",
    ]) + if_blade_gemm([
        "@blade_gemm//:blade_gemm",
        "@mkl_static//:mkl_static_lib",
    ]) + if_mkldnn([
        ":disc_mkl_bazel",
        ":ral_md5",
    ]),
    data=if_blade_gemm([
        "@mkl_static//:lib",
    ]),
    alwayslink=1,
)

# ral_tf_context_bridge: TF-bridge variant of :ral_tf_context. It builds the
# same TF context/kernel sources against :common_context_bridge and :tf_deps,
# with DISC_BUILD_FROM_TF_BRIDGE defined plus the ROCm and internal-serving
# flags where configured.
cc_library(
    name = "ral_tf_context_bridge",
    srcs = [
        "context/tensorflow/tf_context_impl.cc",
        "context/tensorflow/tf_kernel_impl.cc",
    ],
    hdrs = ["context/tensorflow/tf_context_impl.h"],
    copts = ["-DDISC_BUILD_FROM_TF_BRIDGE"] + if_rocm_is_configured([
        "-DTENSORFLOW_USE_ROCM=1",
        "-x rocm",
    ]) + if_internal_serving([
        "-DTF_1_12",
        "-DIS_PAI_TF",
    ]),
    deps = [
        ":context_util",
        ":common_context_bridge",
        ":ral_context",
        ":ral_cpu_driver",
        ":ral_gpu_driver",
        ":ral_library",
        "@local_config_cuda//cuda:cuda_headers",
        ":tf_deps",
    ],
    alwayslink = 1,
)

# ral_bridge: aggregate target with no sources of its own, bundling the
# TF-bridge common context and TF context libraries.
cc_library(
    name = "ral_bridge",
    deps = [
        ":common_context_bridge",
        ":ral_tf_context_bridge",
    ],
    alwayslink = 1,
)
